In [2]:
import numpy as np 
import pandas as pd
from tqdm import tqdm
import cv2
import os
import json
import glob
import random
import collections
import shutil
import itertools
import pydicom
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix

# import plotly.graph_objs as go
# from plotly.offline import init_notebook_mode, iplot
# from plotly import tools

# from keras.preprocessing.image import ImageDataGenerator
# from keras.applications.vgg16 import VGG16, preprocess_input
# from keras import layers
# from keras.models import Model, Sequential
# from keras.callbacks import EarlyStopping

# init_notebook_mode(connected=True)
RANDOM_SEED = 42
In [2]:
# Move the API key into the location the kaggle client expects.
# NOTE(review): follow up with `chmod 600 /home/jupyter/.kaggle/kaggle.json` —
# the warning printed below says the key is world-readable.
!mv kaggle.json /home/jupyter/.kaggle/kaggle.json
import kaggle
Warning: Your Kaggle API key is readable by other users on this system! To fix this, you can run 'chmod 600 /home/jupyter/.kaggle/kaggle.json'
In [5]:
# Download the competition data (~12 GB zip) into the current directory.
!kaggle competitions download -c rsna-miccai-brain-tumor-radiogenomic-classification
Warning: Your Kaggle API key is readable by other users on this system! To fix this, you can run 'chmod 600 /home/jupyter/.kaggle/kaggle.json'
Downloading rsna-miccai-brain-tumor-radiogenomic-classification.zip to /home/jupyter
100%|██████████████████████████████████████▉| 12.3G/12.3G [01:51<00:00, 196MB/s]
100%|███████████████████████████████████████| 12.3G/12.3G [01:51<00:00, 119MB/s]
In [6]:
# Unpack the downloaded competition archive into the working directory.
import zipfile

path_to_zip_file = "rsna-miccai-brain-tumor-radiogenomic-classification.zip"
directory_to_extract_to = "./"
with zipfile.ZipFile(path_to_zip_file, 'r') as archive:
    archive.extractall(directory_to_extract_to)
In [17]:
# Ground-truth labels: one row per patient, BraTS21ID -> MGMT_value (0/1).
train_df = pd.read_csv("train_labels.csv")
train_df
Out[17]:
BraTS21ID MGMT_value
0 0 1
1 2 1
2 3 0
3 5 1
4 6 1
... ... ...
580 1005 1
581 1007 1
582 1008 1
583 1009 0
584 1010 0

585 rows × 2 columns

In [18]:
def load_dicom(path):
    """Read a DICOM file and return its pixel data as a uint8 image in [0, 255].

    Parameters
    ----------
    path : str
        Path to the .dcm file.

    Returns
    -------
    np.ndarray
        2-D uint8 array, min-max scaled to 0-255 (an all-constant slice
        stays all-zero).
    """
    # pydicom.read_file is a deprecated alias; dcmread is the supported API.
    dicom = pydicom.dcmread(path)
    data = dicom.pixel_array
    # Shift so the minimum is 0, then scale to [0, 1] unless the slice is flat.
    data = data - np.min(data)
    if np.max(data) != 0:
        data = data / np.max(data)
    data = (data * 255).astype(np.uint8)
    return data


def visualize_sample(
    brats21id, 
    slice_i,
    mgmt_value,
    types=("FLAIR", "T1w", "T1wCE", "T2w")
):
    """Show one slice from each MRI modality for a single patient.

    `slice_i` is a fraction in [0, 1) selecting the relative slice position
    within each series; `mgmt_value` is only used for the figure title.
    """
    plt.figure(figsize=(16, 5))
    patient_path = os.path.join("train/", str(brats21id).zfill(5))
    for subplot_idx, scan_type in enumerate(types, 1):
        # Files are named like Image-<n>.dcm; sort numerically by <n>.
        scan_files = glob.glob(os.path.join(patient_path, scan_type, "*"))
        scan_files.sort(key=lambda p: int(p[:-4].split("-")[-1]))
        image = load_dicom(scan_files[int(len(scan_files) * slice_i)])
        plt.subplot(1, 4, subplot_idx)
        plt.imshow(image, cmap="gray")
        plt.title(f"{scan_type}", fontsize=16)
        plt.axis("off")

    plt.suptitle(f"MGMT_value: {mgmt_value}", fontsize=16)
    plt.show()
In [12]:
# Eyeball ten random patients, one middle slice per modality.
for idx in random.sample(range(train_df.shape[0]), 10):
    row = train_df.iloc[idx]
    visualize_sample(brats21id=row["BraTS21ID"], mgmt_value=row["MGMT_value"], slice_i=0.5)
In [35]:
from matplotlib import animation, rc
rc('animation', html='jshtml')


def create_animation(ims):
    """Build a ~24 fps grayscale animation from a list of 2-D image arrays."""
    fig = plt.figure(figsize=(6, 6))
    plt.axis('off')
    frame = plt.imshow(ims[0], cmap="gray")

    def _update(frame_idx):
        # Swap the displayed array in place instead of redrawing the axes.
        frame.set_array(ims[frame_idx])
        return [frame]

    return animation.FuncAnimation(fig, _update, frames=len(ims), interval=1000 // 24)
In [36]:
def load_dicom_line(path):
    """Load every DICOM slice under `path` in numeric filename order,
    dropping slices that are entirely black."""
    slice_paths = sorted(
        glob.glob(os.path.join(path, "*")),
        key=lambda p: int(p[:-4].split("-")[-1]),
    )
    return [img for img in map(load_dicom, slice_paths) if img.max() != 0]
In [37]:
# Animate patient 00002's FLAIR series (non-black slices only).
images = load_dicom_line("train/00002/FLAIR")
create_animation(images)
Out[37]:
In [38]:
# Animate patient 00000's FLAIR series for comparison.
images = load_dicom_line("train/00000/FLAIR")
create_animation(images)
Out[38]:
In [27]:
import re
from pydicom.pixel_data_handlers.util import apply_voi_lut
# if os.path.exists("../input/rsna-miccai-brain-tumor-radiogenomic-classification"):
#     data_directory = '../input/rsna-miccai-brain-tumor-radiogenomic-classification'
#     pytorch3dpath = "../input/efficientnetpyttorch3d/EfficientNet-PyTorch-3D"
# else:
#     data_directory = '/media/roland/data/kaggle/rsna-miccai-brain-tumor-radiogenomic-classification'
#     pytorch3dpath = "EfficientNet-PyTorch-3D"
    

# MRI modalities acquired per patient, the square slice resolution, and the
# number of central slices kept per 3-D volume.
mri_types = ['FLAIR','T1w','T1wCE','T2w']
SIZE = 256
NUM_IMAGES = 64
def load_dicom_image(path, img_size=SIZE, voi_lut=True, rotate=0):
    """Load one DICOM slice, optionally VOI-LUT-windowed and rotated,
    resized to (img_size, img_size).

    Parameters
    ----------
    path : str
        Path to the .dcm file.
    img_size : int
        Output side length in pixels.
    voi_lut : bool
        Apply the VOI LUT from the DICOM header when True.
    rotate : int
        0 = no rotation; 1/2/3 select 90° CW, 90° CCW, 180° respectively.

    Returns
    -------
    np.ndarray of shape (img_size, img_size).
    """
    # read_file is a deprecated pydicom alias for dcmread.
    dicom = pydicom.dcmread(path)
    # BUG FIX: the original decoded dicom.pixel_array an extra time before
    # this branch and immediately discarded it — decode only once here.
    if voi_lut:
        data = apply_voi_lut(dicom.pixel_array, dicom)
    else:
        data = dicom.pixel_array

    if rotate > 0:
        rot_choices = [0, cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE, cv2.ROTATE_180]
        data = cv2.rotate(data, rot_choices[rotate])

    data = cv2.resize(data, (img_size, img_size))
    return data


def load_dicom_images_3d(scan_id, num_imgs=NUM_IMAGES, img_size=SIZE, mri_type="FLAIR", split="train", rotate=0):
    """Load the central `num_imgs` slices of one scan as a normalized volume.

    Returns an array of shape (1, img_size, img_size, depth) scaled to [0, 1],
    zero-padded along the slice axis when the scan has fewer slices.
    """
    # Natural sort: split each path into digit / non-digit runs so that
    # "Image-10.dcm" sorts after "Image-9.dcm".
    def natural_key(name):
        return [int(tok) if tok.isdigit() else tok for tok in re.findall(r'[^0-9]|[0-9]+', name)]

    files = sorted(glob.glob(f"{split}/{scan_id}/{mri_type}/*.dcm"), key=natural_key)

    # Keep a symmetric window of slices around the middle of the scan.
    middle = len(files) // 2
    half = num_imgs // 2
    window = files[max(0, middle - half):min(len(files), middle + half)]

    # Stack to (depth, H, W) then .T yields (W, H, depth).
    img3d = np.stack([load_dicom_image(f, rotate=rotate) for f in window]).T
    missing = num_imgs - img3d.shape[-1]
    if missing > 0:
        padding = np.zeros((img_size, img_size, missing))
        img3d = np.concatenate((img3d, padding), axis=-1)

    # Min-max normalize unless the whole volume is constant.
    if np.min(img3d) < np.max(img3d):
        img3d = img3d - np.min(img3d)
        img3d = img3d / np.max(img3d)

    return np.expand_dims(img3d, 0)
In [21]:
import os
import pandas as pd
# First os.walk() result yields (dirpath, dirnames, filenames) for train/
# itself; `dirs` is the list of patient folder names (reused by later cells).
path, dirs, files = next(os.walk("train/"))

# Sanity check: show the patient folders in lexicographic order.
pd.DataFrame({
    'sortme':dirs
}).sort_values(by='sortme')
Out[21]:
sortme
84 00000
184 00002
506 00003
13 00005
255 00006
... ...
571 01005
461 01007
199 01008
399 01009
244 01010

585 rows × 1 columns

In [22]:
# Attach each zero-padded folder name ("00310") to its label row by joining
# the numeric form of the name against BraTS21ID.
path, dirs, files = next(os.walk("train"))
train_df0 = pd.DataFrame({
    'file': dirs
})
train_df0['file1'] = pd.to_numeric(train_df0['file'])
train_df = pd.merge(train_df0, train_df, left_on='file1', right_on='BraTS21ID', how='left')
train_df
Out[22]:
file file1 BraTS21ID MGMT_value
0 00310 310 310 0
1 00239 239 239 0
2 00324 324 324 0
3 00626 626 626 1
4 00589 589 589 0
... ... ... ... ...
580 00162 162 162 0
581 00604 604 604 1
582 00533 533 533 0
583 00837 837 837 0
584 00356 356 356 0

585 rows × 4 columns

In [23]:
# Lookup table: folder name -> MGMT label. BraTS21ID and file1 are redundant
# with the index, so drop them instead of del-ing in place.
file_ret = train_df.set_index('file').drop(columns=['BraTS21ID', 'file1'])
file_ret
Out[23]:
MGMT_value
file
00310 0
00239 0
00324 0
00626 1
00589 0
... ...
00162 0
00604 1
00533 0
00837 0
00356 0

585 rows × 1 columns

In [39]:
# Number of patient folders in the training set.
total_files = len(os.listdir('train/'))
total_files
Out[39]:
585
In [ ]:
from IPython.display import clear_output
import os
mri_types = ['FLAIR','T1w','T1wCE','T2w']
X=np.zeros((1,256**2))
y=[]
completion=0
total_files = len(os.listdir('train/'))*4
total_arrlen=0
# First pass: count the total number of slices across all scans/modalities so
# the memmap in the next cell can be allocated with the right first dimension.
for type1 in mri_types:
    for i in dirs:
        print(completion/total_files)
        lengt = load_dicom_images_3d(i, img_size=SIZE, mri_type=type1).shape[3]
        total_arrlen = total_arrlen + lengt
        # BUG FIX: `completion` was never incremented here, so the progress
        # line always printed 0.0.
        completion = completion + 1
        clear_output(wait=True)
        print(total_arrlen)
print(total_arrlen) 
#149760
In [48]:
# Disk-backed float32 store for every kept slice; sized by the counting pass above.
f = np.memmap('train_df.dat', dtype=np.float32, mode='w+', shape=(total_arrlen, 256, 256))
In [49]:
mri_types = ['FLAIR','T1w','T1wCE','T2w']
currentrow=0
for type1 in mri_types:
    for i in dirs:
        print(completion/total_files)
        # BUG FIX: each volume was loaded twice (once for its shape, once for
        # the data), doubling the most expensive step — load it once.
        vol = load_dicom_images_3d(i, img_size=SIZE, mri_type=type1)
        lengt = vol.shape[3]
        a = vol.reshape(lengt, 256, 256)
        # Copy each slice into the next free memmap row.
        for k in range(0, lengt):
            f[currentrow, :, :] = a[k, :, :]
            currentrow = currentrow + 1
        # One MGMT label per slice of this scan.
        for j in range(0, a.shape[0]):
            y.append(file_ret.loc[i].item())
        print(i)
        clear_output(wait=True)
        completion = completion + 1
        
0.9995726495726496
00356
In [33]:
mri_types = ['FLAIR','T1w','T1wCE','T2w']
from IPython.display import clear_output
import os

# Rebuild the per-slice label list without rewriting the memmap: each scan
# contributes one copy of its patient label per kept slice.
y = []
completion = 0
total_files = len(os.listdir('train/')) * 4
total_arrlen = 0
for type1 in mri_types:
    for i in dirs:
        print(completion / total_files)
        lengt = load_dicom_images_3d(i, img_size=SIZE, mri_type=type1).shape[3]
        patient_label = file_ret.loc[i].item()
        y.extend(patient_label for _ in range(lengt))
        print(i)
        clear_output(wait=True)
        completion = completion + 1
        
0.9995726495726496
00356
In [52]:
# Push buffered memmap writes to disk.
f.flush()
In [37]:
total_arrlen=149760
# BUG FIX: the right-hand side of this assignment was missing (syntax error).
# Re-open the slice store read-only with a trailing channel axis, matching the
# (149760, 256, 256, 1) shape shown in the saved output and the analogous
# test-set cell at the bottom of the notebook.
X = np.memmap('train_df.dat', dtype=np.float32, mode='r', shape=(total_arrlen, 256, 256, 1))
X.shape
Out[37]:
(149760, 256, 256, 1)
In [41]:
# One label per slice; length must match the memmap's first dimension.
y=np.array(y)
y.shape
Out[41]:
(149760,)
In [31]:
# BUG FIX: the slicing lines were commented out, so X_train/X_test were
# undefined on a fresh kernel; the first 22952 slices are held out for
# validation.
# NOTE(review): this is a slice-level split over modality-concatenated data —
# slices from the same patient can land in both sets. Confirm whether a
# patient-wise split was intended.
X_train = X[22952:, :, :]
X_test = X[0:22952, :, :]
y_train = y[22952:]
y_test = y[0:22952]
# -1 lets numpy infer the row counts instead of hardcoding 126808 / 22952.
X_train = X_train.reshape(-1, 256, 256, 1)
X_test = X_test.reshape(-1, 256, 256, 1)
print(X_train.shape, X_test.shape, y_train.shape, y_test.shape)
(126808, 256, 256, 1) (22952, 256, 256, 1) (126808,) (22952,)
In [32]:
import tensorflow as tf
import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.layers import Conv2D, MaxPooling2D
from tensorflow.keras.layers import Dense, Flatten, Activation, Dropout, BatchNormalization
from tensorflow.keras import regularizers


# tensorflow.keras.backend.clear_session()
# NOTE(review): Conv2D(64, 3, 3, ...) is Keras-1 style — in tf.keras the third
# positional argument is `strides`, so every conv below runs with kernel 3 AND
# stride 3, which is why the saved summary shows 86x86 after the first layer.
# If 3x3 kernels with stride 1 were intended, these should be
# Conv2D(64, (3, 3), ...) — confirm before changing, as the fix greatly
# enlarges the Flatten/Dense parameter count.
model_best=keras.models.Sequential()
#CONV2D Layer 1 64 node
model_best.add(Conv2D(64, 3,3, padding='same', kernel_initializer='he_uniform', input_shape=(256,256,1)))
model_best.add(Activation('relu'))
model_best.add(BatchNormalization())
model_best.add(keras.layers.Conv2D(64,3,3, kernel_initializer='he_uniform', padding = "same"))
model_best.add(Activation('relu'))
model_best.add(BatchNormalization())
model_best.add(MaxPooling2D(pool_size=(2,2), padding='same'))
# NOTE(review): a second ReLU after pooling re-clips values that the preceding
# BatchNormalization may have pushed negative; the parallel 128-filter block
# below omits it — confirm which was intended.
model_best.add(Activation('relu'))
model_best.add(Dropout(0.2))

#CONV2D Layer 2 128 node
model_best.add(keras.layers.Conv2D(128,3,3,kernel_initializer='he_uniform', padding = 'same'))
model_best.add(Activation('relu'))
model_best.add(BatchNormalization())
model_best.add(keras.layers.Conv2D(128,3,3,kernel_initializer='he_uniform', padding = "same"))
model_best.add(Activation('relu'))
model_best.add(BatchNormalization())
model_best.add(MaxPooling2D(pool_size=(2,2), padding='same'))
model_best.add(Dropout(0.4))

# #CONV2D Layer 4 256 node
# model_best.add(keras.layers.Conv2D(256,3,3,kernel_initializer='he_uniform', padding = 'same'))
# model_best.add(Activation('relu'))
# model_best.add(BatchNormalization())
# model_best.add(keras.layers.Conv2D(256,3,3,kernel_initializer='he_uniform', padding = "same"))
# model_best.add(Activation('relu'))
# model_best.add(BatchNormalization())
# model_best.add(MaxPooling2D(pool_size=(2,2), padding='same'))
# model_best.add(Dropout(0.5))

# Dense net
model_best.add(keras.layers.Flatten())
model_best.add(keras.layers.Dense(128,activation='relu', kernel_initializer='he_uniform'))
model_best.add(keras.layers.Dropout(0.5))

# Softmax Classification

# Two-way softmax for the binary MGMT label, paired with
# sparse_categorical_crossentropy in the next cell.
model_best.add(keras.layers.Dense(2))
model_best.add(keras.layers.Activation('softmax'))
In [33]:
# NOTE(review): `decay` is a legacy RMSprop argument (dropped from newer Keras
# optimizers) — confirm the installed TF version or switch to a LR schedule.
model_best.compile(optimizer=RMSprop(learning_rate=0.001,decay=1e-6), loss='sparse_categorical_crossentropy', metrics=['accuracy']) 
model_best.summary()
Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_12 (Conv2D)           (None, 86, 86, 64)        640       
_________________________________________________________________
activation_18 (Activation)   (None, 86, 86, 64)        0         
_________________________________________________________________
batch_normalization_12 (Batc (None, 86, 86, 64)        256       
_________________________________________________________________
conv2d_13 (Conv2D)           (None, 29, 29, 64)        36928     
_________________________________________________________________
activation_19 (Activation)   (None, 29, 29, 64)        0         
_________________________________________________________________
batch_normalization_13 (Batc (None, 29, 29, 64)        256       
_________________________________________________________________
max_pooling2d_6 (MaxPooling2 (None, 15, 15, 64)        0         
_________________________________________________________________
activation_20 (Activation)   (None, 15, 15, 64)        0         
_________________________________________________________________
dropout_9 (Dropout)          (None, 15, 15, 64)        0         
_________________________________________________________________
conv2d_14 (Conv2D)           (None, 5, 5, 128)         73856     
_________________________________________________________________
activation_21 (Activation)   (None, 5, 5, 128)         0         
_________________________________________________________________
batch_normalization_14 (Batc (None, 5, 5, 128)         512       
_________________________________________________________________
conv2d_15 (Conv2D)           (None, 2, 2, 128)         147584    
_________________________________________________________________
activation_22 (Activation)   (None, 2, 2, 128)         0         
_________________________________________________________________
batch_normalization_15 (Batc (None, 2, 2, 128)         512       
_________________________________________________________________
max_pooling2d_7 (MaxPooling2 (None, 1, 1, 128)         0         
_________________________________________________________________
dropout_10 (Dropout)         (None, 1, 1, 128)         0         
_________________________________________________________________
flatten_3 (Flatten)          (None, 128)               0         
_________________________________________________________________
dense_6 (Dense)              (None, 128)               16512     
_________________________________________________________________
dropout_11 (Dropout)         (None, 128)               0         
_________________________________________________________________
dense_7 (Dense)              (None, 2)                 258       
_________________________________________________________________
activation_23 (Activation)   (None, 2)                 0         
=================================================================
Total params: 277,314
Trainable params: 276,546
Non-trainable params: 768
_________________________________________________________________
In [34]:
# NOTE(review): use_multiprocessing/workers only apply to generator/Sequence
# inputs — presumably no-ops for in-memory arrays like X_train; verify.
history_best=model_best.fit(X_train, y_train, validation_data=(X_test, y_test), epochs=50, use_multiprocessing=True, workers=8)
Epoch 1/50
3963/3963 [==============================] - 530s 134ms/step - loss: 0.7215 - accuracy: 0.5585 - val_loss: 0.7188 - val_accuracy: 0.5554
Epoch 2/50
3963/3963 [==============================] - 568s 143ms/step - loss: 0.6538 - accuracy: 0.6058 - val_loss: 0.8443 - val_accuracy: 0.5443
Epoch 3/50
2417/3963 [=================>............] - ETA: 3:32 - loss: 0.6239 - accuracy: 0.6332
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
/tmp/ipykernel_2154/456927656.py in <module>
----> 1 history_best=model_best.fit(X_train, y_train, validation_data=(X_test, y_test), epochs=50, use_multiprocessing=True, workers=8)

/opt/conda/lib/python3.7/site-packages/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
   1182                 _r=1):
   1183               callbacks.on_train_batch_begin(step)
-> 1184               tmp_logs = self.train_function(iterator)
   1185               if data_handler.should_sync:
   1186                 context.async_wait()

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    883 
    884       with OptionalXlaContext(self._jit_compile):
--> 885         result = self._call(*args, **kwds)
    886 
    887       new_tracing_count = self.experimental_get_tracing_count()

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    915       # In this case we have created variables on the first call, so we run the
    916       # defunned version which is guaranteed to never create variables.
--> 917       return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
    918     elif self._stateful_fn is not None:
    919       # Release the lock early so that multiple threads can perform the call

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in __call__(self, *args, **kwargs)
   3038        filtered_flat_args) = self._maybe_define_function(args, kwargs)
   3039     return graph_function._call_flat(
-> 3040         filtered_flat_args, captured_inputs=graph_function.captured_inputs)  # pylint: disable=protected-access
   3041 
   3042   @property

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1962       # No tape is watching; skip to running the function.
   1963       return self._build_call_outputs(self._inference_function.call(
-> 1964           ctx, args, cancellation_manager=cancellation_manager))
   1965     forward_backward = self._select_forward_and_backward_functions(
   1966         args,

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in call(self, ctx, args, cancellation_manager)
    594               inputs=args,
    595               attrs=attrs,
--> 596               ctx=ctx)
    597         else:
    598           outputs = execute.execute_with_cancellation(

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     58     ctx.ensure_initialized()
     59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60                                         inputs, attrs, num_outputs)
     61   except core._NotOkStatusException as e:
     62     if name is not None:

KeyboardInterrupt: 
In [ ]:
from tensorflow.keras import layers
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
from tensorflow import keras
import tensorflow as tf

# ViT hyperparameters.
wd = 0.0001
lr = 0.001
batch = 256
epochs = 80
# NOTE(review): image_size is used as a side length below
# (Resizing(image_size, image_size)); 65536 would build a 65536x65536 image
# and exhaust memory — presumably 256 was intended. Confirm.
image_size = 65536
patch_size = 1024
patch_dist = 64**2
projection = 64
num_heads = 4
transformer_val = [128, 64]
# NOTE(review): this rebinds the `layers` module imported above to the int 8;
# later code must not reach tensorflow.keras.layers through this name.
layers = 8
mlp_head_units = [2048, 1024]
data_augmentation = keras.Sequential(
    [
        keras.layers.experimental.preprocessing.Normalization(),
        keras.layers.experimental.preprocessing.Resizing(image_size, image_size),
        keras.layers.experimental.preprocessing.RandomFlip("horizontal"),
        keras.layers.experimental.preprocessing.RandomRotation(factor=0.02),
        keras.layers.experimental.preprocessing.RandomZoom(
            height_factor=0.2, width_factor=0.2
        ),
    ],
    name="data_augmentation",
)
# NOTE(review): `x_train` is not defined anywhere above (earlier cells define
# X_train) — this line will NameError on a fresh run; confirm the intended array.
data_augmentation.layers[0].adapt(x_train)
def mlp(x, hidden_units, dropout_rate):
    """Apply a stack of Dense(GELU) + Dropout layers, one per hidden width."""
    for width in hidden_units:
        x = keras.layers.Dropout(dropout_rate)(
            keras.layers.Dense(width, activation=tf.nn.gelu)(x)
        )
    return x

class Patches(keras.layers.Layer):
    """Keras layer that cuts a batch of images into flattened square patches."""

    def __init__(self, patch_size):
        super(Patches, self).__init__()
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        # Non-overlapping patch_size x patch_size tiles; each patch comes back
        # already flattened along the last axis.
        raw_patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        flat_dim = raw_patches.shape[-1]
        return tf.reshape(raw_patches, [batch_size, -1, flat_dim])
# Visualize the patch decomposition of one random training image.
plt.figure(figsize=(4, 4))
# NOTE(review): `x_train` is undefined in this notebook (earlier cells define
# X_train) — confirm the intended array.
image = x_train[np.random.choice(range(x_train.shape[0]))]
plt.imshow(image.astype("uint8"))
plt.axis("off")
resized_image = tf.image.resize(
    tf.convert_to_tensor([image]), size=(image_size, image_size)
)
patches = Patches(patch_size)(resized_image)
print(f"Image size: {image_size} X {image_size}")
print(f"Patch size: {patch_size} X {patch_size}")
print(f"Patches per image: {patches.shape[1]}")
print(f"Elements per patch: {patches.shape[-1]}")
n = int(np.sqrt(patches.shape[1]))
plt.figure(figsize=(4, 4))
for i, patch in enumerate(patches[0]):
    ax = plt.subplot(n, n, i + 1)
    # NOTE(review): reshaping to 3 channels assumes RGB input, but the MRI
    # slices here are single-channel — verify.
    patch_img = tf.reshape(patch, (patch_size, patch_size, 3))
    plt.imshow(patch_img.numpy().astype("uint8"))
    plt.axis("off")
In [ ]:
class PatchEncoder(keras.layers.Layer):
    """Linearly project flattened patches and add a learned position embedding."""

    def __init__(self, patch_dist, projection):
        super(PatchEncoder, self).__init__()
        self.patch_dist = patch_dist
        self.projection = keras.layers.Dense(units=projection)
        self.position_embedding = keras.layers.Embedding(
            input_dim=patch_dist, output_dim=projection
        )

    def call(self, patch):
        # One embedding index per patch position: 0 .. patch_dist - 1.
        indices = tf.range(start=0, limit=self.patch_dist, delta=1)
        return self.projection(patch) + self.position_embedding(indices)

def create_vit_classifier():
    """Build the ViT: patchify -> encode -> transformer blocks -> MLP head.

    Relies on module-level globals: data_augmentation, patch_size, patch_dist,
    projection, num_heads, transformer_val, layers (the int 8), mlp_head_units.
    """
    # NOTE(review): shape (256, 256) has no channel axis, but
    # tf.image.extract_patches inside Patches expects 4-D NHWC input — verify.
    inputs = keras.layers.Input(shape=(256, 256))
    augmented = data_augmentation(inputs)
    patches = Patches(patch_size)(augmented)
    encoded_patches = PatchEncoder(patch_dist, projection)(patches)

    # `layers` is the global int 8 (it shadows tensorflow.keras.layers).
    for _ in range(layers):
        # Pre-norm transformer block: LN -> MHA -> residual -> LN -> MLP -> residual.
        x1 = keras.layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        attention_output = keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection, dropout=0.1
        )(x1, x1)
        x2 = keras.layers.Add()([attention_output, encoded_patches])
        x3 = keras.layers.LayerNormalization(epsilon=1e-6)(x2)
        x3 = mlp(x3, hidden_units=transformer_val, dropout_rate=0.1)
        encoded_patches = keras.layers.Add()([x3, x2])
    representation = keras.layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    representation = keras.layers.Flatten()(representation)
    representation = keras.layers.Dropout(0.5)(representation)
    features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5)
    # NOTE(review): 10 output logits, but MGMT classification is binary —
    # confirm the intended class count.
    logits = keras.layers.Dense(10)(features)
    model = keras.Model(inputs=inputs, outputs=logits)
    return model

def run_experiment(model):
    """Compile the model, train with best-val-accuracy checkpointing, then
    restore the best weights and evaluate.

    NOTE(review): trains on globals X / y_train (their lengths may differ) and
    evaluates on lowercase x_test / y_test while earlier cells define X_test —
    confirm the intended arrays.
    """
    # NOTE(review): tfa.optimizers.AdamW's parameter is `learning_rate`; `lr`
    # is a legacy alias that newer tensorflow-addons releases reject — verify.
    optimizer = tfa.optimizers.AdamW(
        lr=lr, weight_decay=wd)
    model.compile(
        optimizer=optimizer,
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
            keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),],)
    # Keep only the weights of the best validation-accuracy epoch.
    checkpoint_filepath = "/tmp/checkpoint"
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,)
    history = model.fit(
        x=X,
        y=y_train,
        batch_size=batch,
        epochs=epochs,
        validation_split=0.1,
        callbacks=[checkpoint_callback],)
    model.load_weights(checkpoint_filepath)
    _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")
    return history

# Build and train the ViT, keeping the training history for inspection.
vit_classifier = create_vit_classifier()
history = run_experiment(vit_classifier)
In [ ]:
# BUG FIX: model.fit returns a History object; the per-epoch metrics live in
# its .history dict, which is what DataFrame can consume.
pd.DataFrame(history.history)
In [ ]:
from IPython.display import clear_output
import os
mri_types = ['FLAIR','T1w','T1wCE','T2w']
X=np.zeros((1,256**2))
y=[]
completion=0
# BUG FIX: this pass sizes the *test* memmap, but it previously iterated the
# training folder list (`dirs`) and scaled progress by the train count; IDs
# absent under test/ would make glob return nothing and np.stack crash.
test_dirs = sorted(os.listdir('test/'))
total_files = len(test_dirs)*4
total_arrlen=0
for type1 in mri_types:
    for i in test_dirs:
        print(completion/total_files)
        lengt=load_dicom_images_3d(i, img_size=SIZE, mri_type=type1, split='test').shape[3]
        total_arrlen=total_arrlen+lengt
        # BUG FIX: completion was never incremented, so progress stayed at 0.
        completion=completion+1
        clear_output(wait=True)
        print(total_arrlen)
print(total_arrlen) 
In [ ]:
# Disk-backed store for the test-set slices, sized by the counting pass above.
f = np.memmap('test_df.dat', dtype=np.float32, mode='w+', shape=(total_arrlen, 256, 256))
In [ ]:
mri_types = ['FLAIR','T1w','T1wCE','T2w']
# BUG FIX: iterate the test-set patient IDs (the original reused the training
# `dirs` list with split='test') and load each volume once instead of twice.
test_dirs = sorted(os.listdir('test/'))
currentrow=0
for type1 in mri_types:
    for i in test_dirs:
        print(completion/total_files)
        vol = load_dicom_images_3d(i, img_size=SIZE, mri_type=type1, split='test')
        lengt = vol.shape[3]
        a = vol.reshape(lengt, 256, 256)
        for k in range(0,lengt):
            f[currentrow, :, :] = a[k, :, :]
            currentrow=currentrow+1
        # NOTE(review): test patients have no MGMT labels in train_labels.csv,
        # so file_ret cannot contain them; the unguarded lookup would KeyError.
        # Guarded here — confirm whether collecting y for the test split is
        # needed at all.
        if i in file_ret.index:
            for j in range(0, a.shape[0]):
                y.append(file_ret.loc[i].item())
        print(i)
        clear_output(wait=True)
        completion=completion+1
In [ ]:
# Push buffered writes to disk, then re-open read-only with a trailing
# channel axis so the array matches the model's (256, 256, 1) input.
f.flush()
X_testset = np.memmap('test_df.dat', dtype=np.float32, mode='r', shape=(total_arrlen, 256, 256, 1))
X_testset.shape